{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_"
      },
      "source": [
        "# Haiku Basics\n",
        "\n",
        "In this Colab, you will learn the basics of Haiku.\n",
        "\n",
        "**What and Why ?**\n",
        "\n",
        "[Haiku](https://github.com/deepmind/dm-haiku) is a simple neural network library for JAX that enables users to use familiar object-oriented programming models while allowing full access to JAX's pure function transformations. Haiku is designed to make the common things we do such as managing model parameters and other model state simpler and similar in spirit to the [Sonnet](https://github.com/deepmind/sonnet) library that has been widely used across DeepMind. It preserves Sonnet's module-based programming model for state management while retaining access to JAX's function transformations. Haiku can be expected to compose with other libraries and work well with the rest of JAX.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "executionInfo": {
          "elapsed": 4044,
          "status": "ok",
          "timestamp": 1658932123202,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "lvPM0zvUC2Mi"
      },
      "outputs": [],
      "source": [
        "import haiku as hk\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import numpy as np"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jOGQXHTrC2Mj"
      },
      "source": [
        "### A first example with hk.transform\n",
        "\n",
        "As an initial introduction to Haiku, let us construct a linear module with weights and biases with custom initializations.\n",
        "\n",
        "Similar to Sonnet modules, Haiku modules are Python objects that hold references to their own parameters, other modules, and methods that apply functions on user inputs. On the other hand, since JAX operates on pure function transformations, Haiku modules cannot be instantiated verbatim. Rather, the modules need to be wrapped into pure function transformations. \n",
        "\n",
        "Haiku provides a simple function transformation, `hk.transform`, that turns functions that use these object-oriented, functionally \"impure\" modules into pure functions that can be used with JAX."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "executionInfo": {
          "elapsed": 54,
          "status": "ok",
          "timestamp": 1658932123346,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "TpoUknrxC2Mk"
      },
      "outputs": [],
      "source": [
        "class MyLinear1(hk.Module):\n",
        "\n",
        "  def __init__(self, output_size, name=None):\n",
        "    super().__init__(name=name)\n",
        "    self.output_size = output_size\n",
        "  \n",
        "  def __call__(self, x):\n",
        "    j, k = x.shape[-1], self.output_size\n",
        "    w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j))\n",
        "    w = hk.get_parameter(\"w\", shape=[j, k], dtype=x.dtype, init=w_init)\n",
        "    b = hk.get_parameter(\"b\", shape=[k], dtype=x.dtype, init=jnp.ones)\n",
        "    return jnp.dot(x, w) + b"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "executionInfo": {
          "elapsed": 55,
          "status": "ok",
          "timestamp": 1658932123496,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "Ewq6DBo5C2Ml"
      },
      "outputs": [],
      "source": [
        "def _forward_fn_linear1(x):\n",
        "  module = MyLinear1(output_size=2)\n",
        "  return module(x)\n",
        "\n",
        "forward_linear1 = hk.transform(_forward_fn_linear1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "27Z8D47zC2Ml"
      },
      "source": [
        "We see that the forward wrapper object now contains two methods, `init` and `apply`, that are used to initialize the variables and do forward inference on the module."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "executionInfo": {
          "elapsed": 56,
          "status": "ok",
          "timestamp": 1658932123646,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "HvEB2HjhC2Mm",
        "outputId": "9a672811-d0a8-4d92-f82e-5367a4a4687a"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "Transformed(init=\u003cfunction without_state.\u003clocals\u003e.init_fn at 0x7f2e310354c0\u003e, apply=\u003cfunction without_state.\u003clocals\u003e.apply_fn at 0x7f2e31035550\u003e)"
            ]
          },
          "execution_count": 5,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "forward_linear1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7Y9ScXfFC2Mn"
      },
      "source": [
        "Calling the init method will initialize the parameters of the network and return them, as can be seen below. The init method takes a `jax.random.PRNGKey` and a sample input (usually just some dummy values to tell the networks about the expected shapes).\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "executionInfo": {
          "elapsed": 1374,
          "status": "ok",
          "timestamp": 1658932125122,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "fh09CDJtC2Mo",
        "outputId": "6f834e1c-9426-44b4-b8ef-2d8e055b32e3"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "{'my_linear1': {'w': DeviceArray([[-0.30350363,  0.5123802 ],\n",
            "             [ 0.08009142, -0.3163005 ],\n",
            "             [ 0.6056666 ,  0.5820702 ]], dtype=float32), 'b': DeviceArray([1., 1.], dtype=float32)}}\n"
          ]
        }
      ],
      "source": [
        "dummy_x = jnp.array([[1., 2., 3.]])\n",
        "rng_key = jax.random.PRNGKey(42)\n",
        "\n",
        "params = forward_linear1.init(rng=rng_key, x=dummy_x)\n",
        "print(params)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fB0B2wwhC2Mp"
      },
      "source": [
        "We can now use the params to apply the forward function to some inputs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "executionInfo": {
          "elapsed": 61,
          "status": "ok",
          "timestamp": 1658932125286,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "tktxIIypC2Mq",
        "outputId": "3f8331bb-1cec-4114-f437-ef3c64d044fd"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Output 1 : [[2.6736789 2.6259897]]\n",
            "Output 2 (same as output 1): [[2.6736789 2.6259897]]\n",
            "Output 3 : [[3.820442 4.960439]\n",
            " [4.967205 7.294889]]\n"
          ]
        }
      ],
      "source": [
        "sample_x = jnp.array([[1., 2., 3.]])\n",
        "sample_x_2 = jnp.array([[4., 5., 6.], [7., 8., 9.]])\n",
        "\n",
        "output_1 = forward_linear1.apply(params=params, x=sample_x, rng=rng_key)\n",
        "# Outputs are identical for given inputs since the forward inference is non-stochastic.\n",
        "output_2 = forward_linear1.apply(params=params, x=sample_x, rng=rng_key)\n",
        "\n",
        "output_3 = forward_linear1.apply(params=params, x=sample_x_2, rng=rng_key)\n",
        "\n",
        "print(f'Output 1 : {output_1}')\n",
        "print(f'Output 2 (same as output 1): {output_2}')\n",
        "print(f'Output 3 : {output_3}')"
      ]
    },
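    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because `apply` is a pure function, it can be passed to JAX transformations. As a minimal sketch, the cell below wraps it in `jax.jit`, reusing `forward_linear1`, `params`, `sample_x` and `rng_key` from the cells above. The names `jitted_apply` and `jitted_output` are introduced here only for illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# jax.jit compiles the pure apply function; the call signature is unchanged.\n",
        "jitted_apply = jax.jit(forward_linear1.apply)\n",
        "\n",
        "# Same arguments as the un-jitted call above.\n",
        "jitted_output = jitted_apply(params=params, x=sample_x, rng=rng_key)\n",
        "print(f'Output from jitted apply : {jitted_output}')"
      ]
    },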
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tz-F4CFfC2Mq"
      },
      "source": [
        "**Inference without random key**\n",
        "\n",
        "The module that we built is inherently non-stochastic. In that case, passing a random key to the apply method seems redundant. Haiku offers another transformation `hk.without_apply_rng` which can be further wrapped around our `hk.transform` method."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "executionInfo": {
          "elapsed": 58,
          "status": "ok",
          "timestamp": 1658932125450,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "Br3BO7k_C2Mr",
        "outputId": "ba7ff9f3-6ccd-4c43-a56b-5a8e6c56757e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Output without random key in forward pass \n",
            " [[2.6736789 2.6259897]]\n"
          ]
        }
      ],
      "source": [
        "forward_without_rng = hk.without_apply_rng(hk.transform(_forward_fn_linear1))\n",
        "params = forward_without_rng.init(rng=rng_key, x=sample_x)\n",
        "output = forward_without_rng.apply(x=sample_x, params=params)\n",
        "print(f'Output without random key in forward pass \\n {output_1}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_2dKaV7TC2Mr"
      },
      "source": [
        "We can also mutate the parameters and then do forward inference to generate a different output for the same inputs. This is what is done to apply gradient descent to our parameters while learning."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "executionInfo": {
          "elapsed": 56,
          "status": "ok",
          "timestamp": 1658932125601,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "Gds0RCNGC2Mr",
        "outputId": "59744ce3-b3b6-4813-d683-10c8d3e2133c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Mutated params \n",
            " : {'my_linear1': {'b': DeviceArray([2., 2.], dtype=float32), 'w': DeviceArray([[0.69649637, 1.5123801 ],\n",
            "             [1.0800915 , 0.6836995 ],\n",
            "             [1.6056666 , 1.5820701 ]], dtype=float32)}}\n",
            "Output with mutated params \n",
            " [[9.673679 9.62599 ]]\n"
          ]
        }
      ],
      "source": [
        "mutated_params = jax.tree.map(lambda x: x+1., params)\n",
        "print(f'Mutated params \\n : {mutated_params}')\n",
        "mutated_output = forward_without_rng.apply(x=sample_x, params=mutated_params)\n",
        "print(f'Output with mutated params \\n {mutated_output}')"
      ]
    },
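    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The cell above shifts every parameter by a constant. As a rough sketch of what a single gradient-descent step could look like, the cell below reuses `forward_without_rng`, `params` and `sample_x` from the cells above. The target `sample_y`, the squared-error loss `loss_fn` and the value of `learning_rate` are assumptions made purely for illustration and are not part of the original example."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical target and learning rate, chosen only for illustration.\n",
        "sample_y = jnp.zeros((1, 2))\n",
        "learning_rate = 0.1\n",
        "\n",
        "def loss_fn(params, x, y):\n",
        "  # Mean squared error between the module output and the target.\n",
        "  prediction = forward_without_rng.apply(params=params, x=x)\n",
        "  return jnp.mean((prediction - y) ** 2)\n",
        "\n",
        "# jax.grad differentiates with respect to the first argument (params).\n",
        "grads = jax.grad(loss_fn)(params, sample_x, sample_y)\n",
        "\n",
        "# One plain gradient-descent update applied to every leaf of the params tree.\n",
        "updated_params = jax.tree.map(lambda p, g: p - learning_rate * g, params, grads)\n",
        "\n",
        "print(f'Loss before update: {loss_fn(params, sample_x, sample_y)}')\n",
        "print(f'Loss after update: {loss_fn(updated_params, sample_x, sample_y)}')"
      ]
    },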
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e_oMNX9NC2Ms"
      },
      "source": [
        "### Stateful Inference in Haiku\n",
        "\n",
        "For some modules you might want to maintain and carry over the internal state across function calls. Here, we demonstrate a simple example, where we declare a state variable `counter` within our Haiku transformation which gets updated on each call to the function. Note that we didn't explicitly instantiate this as a Haiku module (the same could be replicated as a hk module as shown earlier).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "executionInfo": {
          "elapsed": 129,
          "status": "ok",
          "timestamp": 1658932125838,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "IjZHAd15C2Ms",
        "outputId": "b0a9a360-408d-411e-8c3a-31ab3ed06eec"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Initial params:\n",
            "{'~': {'multiplier': DeviceArray([1.], dtype=float32)}}\n",
            "Initial state:\n",
            "{'~': {'counter': DeviceArray(1, dtype=int32)}}\n",
            "##########\n",
            "After 1 iterations:\n",
            "Output: [[6.]]\n",
            "State: {'~': {'counter': DeviceArray(2, dtype=int32)}}\n",
            "##########\n",
            "After 2 iterations:\n",
            "Output: [[7.]]\n",
            "State: {'~': {'counter': DeviceArray(3, dtype=int32)}}\n",
            "##########\n",
            "After 3 iterations:\n",
            "Output: [[8.]]\n",
            "State: {'~': {'counter': DeviceArray(4, dtype=int32)}}\n",
            "##########\n"
          ]
        }
      ],
      "source": [
        "def stateful_f(x):\n",
        "  counter = hk.get_state(\"counter\", shape=[], dtype=jnp.int32, init=jnp.ones)\n",
        "  multiplier = hk.get_parameter('multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones)\n",
        "  hk.set_state(\"counter\", counter + 1)\n",
        "  output = x + multiplier * counter\n",
        "  return output\n",
        "\n",
        "stateful_forward = hk.without_apply_rng(hk.transform_with_state(stateful_f))\n",
        "sample_x = jnp.array([[5., ]])\n",
        "params, state = stateful_forward.init(x=sample_x, rng=rng_key)\n",
        "print(f'Initial params:\\n{params}\\nInitial state:\\n{state}')\n",
        "print('##########')\n",
        "for i in range(3):\n",
        "  output, state = stateful_forward.apply(params, state, x=sample_x)\n",
        "  print(f'After {i+1} iterations:\\nOutput: {output}\\nState: {state}')\n",
        "  print('##########')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_LRXQQSpC2Mt"
      },
      "source": [
        "### Built-in Haiku nets and nested modules\n",
        "\n",
        "The usual networks we use such as MLP, Convnets etc. are [defined already in Haiku](https://dm-haiku.readthedocs.io/en/latest/api.html#common-modules) and we can compose those modules to construct our custom Haiku Module.\n",
        "\n",
        "Look at the params dictionary to see how the params are nested in the same way as the modules are nested within our custom Haiku module.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "executionInfo": {
          "elapsed": 656,
          "status": "ok",
          "timestamp": 1658932126592,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "x36JAuQOC2Mt",
        "outputId": "9069b34d-ee71-4c4e-ad41-18a852d72d90"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'custom_linear/~/hk_internal_linear/~/linear_0': {'b': DeviceArray([0., 0.], dtype=float32),\n",
              "  'w': DeviceArray([[ 1.51595   , -0.23353337]], dtype=float32)},\n",
              " 'custom_linear/~/hk_internal_linear/~/linear_1': {'b': DeviceArray([0., 0., 0.], dtype=float32),\n",
              "  'w': DeviceArray([[-0.22075887, -0.27375957,  0.5931483 ],\n",
              "               [ 0.7818068 ,  0.72626334, -0.6860752 ]], dtype=float32)},\n",
              " 'custom_linear/~/old_linear': {'b': DeviceArray([1., 1.], dtype=float32),\n",
              "  'w': DeviceArray([[ 0.28584382,  0.31626168],\n",
              "               [ 0.2335775 , -0.4827032 ],\n",
              "               [-0.14647584, -0.7185701 ]], dtype=float32)}}"
            ]
          },
          "execution_count": 11,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# See: https://dm-haiku.readthedocs.io/en/latest/api.html#common-modules\n",
        "\n",
        "class MyModuleCustom(hk.Module):\n",
        "  def __init__(self, output_size=2, name='custom_linear'):\n",
        "    super().__init__(name=name)\n",
        "    self._internal_linear_1 = hk.nets.MLP(output_sizes=[2, 3], name='hk_internal_linear')\n",
        "    self._internal_linear_2 = MyLinear1(output_size=output_size, name='old_linear')\n",
        "\n",
        "  def __call__(self, x):\n",
        "    return self._internal_linear_2(self._internal_linear_1(x))\n",
        "\n",
        "def _custom_forward_fn(x):\n",
        "  module = MyModuleCustom()\n",
        "  return module(x)\n",
        "\n",
        "custom_forward_without_rng = hk.without_apply_rng(hk.transform(_custom_forward_fn))\n",
        "params = custom_forward_without_rng.init(rng=rng_key, x=sample_x)\n",
        "params"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n2tIinMDC2Mu"
      },
      "source": [
        "### RNG Keys with hk.next_rng_key()\n",
        "\n",
        "The modules that we saw earlier were all non-stochastic. Below we show how to sample random numbers to do stochastic inference.\n",
        "\n",
        "Haiku offers a trivial model for working with random numbers. Within a transformed function, `hk.next_rng_key()` returns a unique rng key.\n",
        "These unique keys are deterministically derived from an initial random key passed into the top-level transformed function, and are thus safe to use with JAX program transformations.\n",
        "\n",
        "Let us define a simple haiku function where we generate two random samples. Note that the next_rng_keys are determined from the initial random key passed to the apply method of the top-level transformed function. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "executionInfo": {
          "elapsed": 57,
          "status": "ok",
          "timestamp": 1658932170511,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "PUucdMsuC2Mu",
        "outputId": "fb9ffc34-c1bf-4c0c-ed60-9fedcbf2e04d"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "INIT:\n",
            "Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))\n",
            "APPLY:\n",
            "Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))\n"
          ]
        }
      ],
      "source": [
        "class HkRandom2(hk.Module):\n",
        "  def __init__(self, rate=0.5):\n",
        "    super().__init__()\n",
        "    self.rate = rate\n",
        "  \n",
        "  def __call__(self, x):\n",
        "    key1 = hk.next_rng_key()\n",
        "    return jax.random.bernoulli(key1, 1.0 - self.rate, shape=x.shape)\n",
        "\n",
        "\n",
        "class HkRandomNest(hk.Module):\n",
        "  def __init__(self, rate=0.5):\n",
        "    super().__init__()\n",
        "    self.rate = rate\n",
        "    self._another_random_module = HkRandom2()\n",
        "\n",
        "  def __call__(self, x):\n",
        "    key2 = hk.next_rng_key()\n",
        "    p1 = self._another_random_module(x)\n",
        "    p2 = jax.random.bernoulli(key2, 1.0 - self.rate, shape=x.shape)\n",
        "    print(f'Bernoullis are  : {p1, p2}')\n",
        "\n",
        "# Note that the modules that are stochastic cannot be wrapped with hk.without_apply_rng()\n",
        "forward = hk.transform(lambda x: HkRandomNest()(x))\n",
        "\n",
        "x = jnp.array(1.)\n",
        "print(\"INIT:\")\n",
        "params = forward.init(rng_key, x=x)\n",
        "print(\"APPLY:\")\n",
        "prediction = forward.apply(params, x=x, rng=rng_key)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KuusEoaVEDEI"
      },
      "source": [
        "**Note  that this means that passing the same random key to multiple calls of the apply function will generate the same stochastic results!**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "executionInfo": {
          "elapsed": 57,
          "status": "ok",
          "timestamp": 1658932230031,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "ekkrgZEjEHKA",
        "outputId": "706fef79-dbdf-4d13-cedf-4eda1cb96cc1"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))\n",
            "Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))\n",
            "Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))\n"
          ]
        }
      ],
      "source": [
        "for _ in range(3):\n",
        "  forward.apply(params, x=x, rng=rng_key)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0_JFIJK2Edn-"
      },
      "source": [
        "Make sure to split off new RNG keys to get different stochastic behavior across `apply` calls, and to never re-use an RNG key. (For a more comprehensive explanation of how to handle random state in JAX, check out this RNG tutorial: https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html.)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "executionInfo": {
          "elapsed": 56,
          "status": "ok",
          "timestamp": 1658932268073,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "Gsv5XFc7ERF_",
        "outputId": "2b5fbb0b-927e-492f-ee3b-1e4aa8e385b2"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Bernoullis are  : (DeviceArray(False, dtype=bool), DeviceArray(False, dtype=bool))\n",
            "Bernoullis are  : (DeviceArray(False, dtype=bool), DeviceArray(True, dtype=bool))\n",
            "Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))\n"
          ]
        }
      ],
      "source": [
        "for _ in range(3):\n",
        "  rng_key, apply_rng_key = jax.random.split(rng_key)\n",
        "  forward.apply(params, x=x, rng=apply_rng_key)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XfG7m6sXFzZd"
      },
      "source": [
        "Haiku also provides `hk.PRNGSequence` which returns an iterator of random keys."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {
        "executionInfo": {
          "elapsed": 57,
          "status": "ok",
          "timestamp": 1658932620703,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "zF7QIcznFlFg",
        "outputId": "e816985a-8aa9-436e-a88b-d70625fb19ca"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(True, dtype=bool))\n",
            "Bernoullis are  : (DeviceArray(False, dtype=bool), DeviceArray(False, dtype=bool))\n",
            "Bernoullis are  : (DeviceArray(False, dtype=bool), DeviceArray(True, dtype=bool))\n"
          ]
        }
      ],
      "source": [
        "rng_sequence = hk.PRNGSequence(rng_key)\n",
        "for _ in range(3):\n",
        "  forward.apply(params, x=x, rng=next(rng_sequence))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3",
        "kind": "private"
      },
      "name": "basics.ipynb",
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
